{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classifying 3D shapes\n",
    "\n",
    "Techniques for analyzing 3D shapes are becoming increasingly important due to the vast number of sensors such as LiDAR that are capturing 3D data, as well as numerous computer graphics applications.  These raw data are typically collected in the form of a _point cloud_, which corresponds to a set of 3D points $\\{p_i | i = 1, \\ldots, n\\}$, where each point $p_i$ is a vector of its $(x, y, z)$ coordinates plus extra\n",
    "feature channels such as color, intensity etc. Typically, Euclidean distance is used to calculate the distances between any two points.\n",
    "\n",
    "By finding suitable representations of these point clouds, machine learning can be used to solve a variety of tasks such as those shown in the figure below.\n",
    "\n",
    "![3d-tasks](images/3d_tasks.png)\n",
    "<div style=\"text-align: left\">\n",
    "   <p style=\"text-align: left;\"> <b>Figure reference:</b> adapted from <a href=\"https://arxiv.org/abs/1612.00593\">arxiv.org/abs/1612.00593</a>. </p>\n",
    "</div>\n",
    "\n",
    "This notebook examines how ``giotto-tda`` can be used to extract topological features from point cloud data and fed to a simple classifier to distinguish 3D shapes.\n",
    "\n",
    "If you are looking at a static version of this notebook and would like to run its contents, head over to [GitHub](https://github.com/giotto-ai/giotto-tda/blob/master/examples/classifying_shapes.ipynb) and download the source.\n",
    "\n",
    "**License: AGPLv3**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate simple shapes\n",
    "\n",
    "To get started, let's generate a synthetic dataset of 10 noisy circles, spheres, and tori, where the effect of noise is to displace the points that sample the surfaces by a random amount in a random direction:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.generate_datasets import make_point_clouds\n",
    "\n",
    "point_clouds_basic, labels_basic = make_point_clouds(n_samples_per_shape=10, n_points=20, noise=0.5)\n",
    "point_clouds_basic.shape, labels_basic.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here the labels are defined to that a circle is 0, a sphere is 1, and a torus is 2. Each point cloud corresponds to a sampling of the continuous surface – in this case 400 points are sampled per shape. As a sanity check, let's visualise these points clouds using ``giotto-tda``'s plotting API:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.plotting import plot_point_cloud\n",
    "\n",
    "plot_point_cloud(point_clouds_basic[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_point_cloud(point_clouds_basic[10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_point_cloud(point_clouds_basic[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## From data to persistence diagrams\n",
    "In raw form, point cloud data is not well suited for most machine learning algorithms because we ultimately need a _feature matrix_ $X$ where each row corresponds to a single sample (i.e. point cloud) and each column corresponds to a particular feature. In our example, each point cloud corresponds to a _collection_ of 3-dimensional vectors, so we need some way of representing this information in terms of a _single_ vector $x^{(i)}$ of feature values.\n",
    "\n",
    "In this notebook we will use persistent homology to generate a topological summary of a point cloud in the form of a so-called persistence diagram. Perhaps the simplest way to understand persistent homology is in terms of _growing balls around each point._ The basic idea is that by keeping track of when the balls intersect we can quantify when topological features like connected components and loops are \"born\" or \"die\".\n",
    "\n",
    "For example, consider two noisy clusters and track their connectivity or \"0-dimensional persistent homology\" as we increase the radius of the balls around each point:\n",
    "\n",
    "![phdim0](images/persistent_homology_0d.gif)\n",
    "<div style=\"text-align: left\">\n",
    "   <p style=\"text-align: left;\"> <b>Figure reference:</b> <a href=\"https://towardsdatascience.com/persistent-homology-with-examples-1974d4b9c3d0\">towardsdatascience.com/persistent-homology-with-examples-1974d4b9c3d0</a>. </p>\n",
    "</div>\n",
    "\n",
    "As the ball radius is grown from 0 to infinity, 0-dimensional persistent homology records when the ball in one connected component first intersects a ball of a _different_ connected component (denoted by a different colour in the animation).  At radius 0, a connected component for each point is _born_ and once any two balls touch we have a _death_ of a connected component with one color persisting and the other color vanishing. The vanishing of a color corresponds to a death, and therefore another point being added to the persistence diagram."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Behind the scenes, this process generating a persistence diagram from data involves several steps:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Construct a simplicial complex\n",
    "\n",
    "The effect of connecting points as we increase some radius $\\epsilon$ results in the creation of geometric objects called simplices.  \n",
    "\n",
    "![vr-complex](images/vr_complex.png)\n",
    "<div style=\"text-align: left\">\n",
    "   <p style=\"text-align: left;\"> <b>Figure reference:</b> <a href=\"https://bit.ly/2z9yP1d\">bit.ly/2z9yP1d</a>. </p>\n",
    "</div>\n",
    "\n",
    "A $k$-simplex $[p_0,\\ldots,p_k]$ is the convex hull of a set of $k+1$ affinely independent points in $\\mathbb{R}^n$. For example the 0-simplex is a point, the 1-simplex is a line, the 2-simplex is a triangular disc, and a 3-simplex a regular tetrahedron:\n",
    "\n",
    "![simplex](images/simplex.png)\n",
    "<div style=\"text-align: left\">\n",
    "   <p style=\"text-align: left;\"> <b>Figure reference:</b> <a href=\"https://bit.ly/2X8AsUX\">bit.ly/2X8AsUX</a>. </p>\n",
    "</div>\n",
    "\n",
    "Linking simplices together results in a simplicial complex $X$, which for computational efficiency is often represented abstractly as finite subsets of the vertex set of $X$. For example, the 2-complex shown below can be written as the set \n",
    "\n",
    "$$X=\\{a,b,c,d,e,[a,b],[b,c],[c,d],[d,e],[e,a],[b,e],[a,b,e]\\}\\,.$$\n",
    "\n",
    "![2-complex](images/2_complex.png)\n",
    "<div style=\"text-align: left\">\n",
    "   <p style=\"text-align: left;\"> <b>Figure reference:</b> <a href=\"https://arxiv.org/abs/1904.11044\">arxiv.org/abs/1904.11044</a>. </p>\n",
    "</div>\n",
    "\n",
    "One of the most common simplicial complexes one encounters in topological data analysis is the [Vietoris-Rips complex](https://en.wikipedia.org/wiki/Vietoris%E2%80%93Rips_complex), which is defined as the set of simplices $[p_0,\\ldots,p_k]$ such that the distance metric $d(p_i,p_j) \\leq \\epsilon$ for all $i,j$.\n",
    "\n",
    "#### 2. Compute Betti numbers\n",
    "\n",
    "Once the simplicial complex is constructed, we can ask questions about its topology. In particular, we can identify the presence of topological invariants such as connected pieces, holes and voids. This is achieved by computing quantities known as _Betti numbers_ which are informally [defined](https://en.wikipedia.org/wiki/Betti_number) as follows:\n",
    "\n",
    "> The $k$th Betti number refers to the number of $k$-dimensional holes on a topological surface. The first few Betti numbers have the following definitions for 0-dimensional, 1-dimensional, and 2-dimensional simplicial complexes:\n",
    "> \n",
    "> * $b_0$ is the number of connected components,\n",
    "> * $b_1$ is the number of one-dimensional or \"circular\" holes,\n",
    "> * $b_2$ is the number of two-dimensional \"voids\" or \"cavities\".\n",
    "\n",
    "By computing Betti numbers of a range of scale values $\\epsilon$, we can track which topological features _persist_ over this range. We can represent these changes in topology (technically homology) in terms of a persistence diagram, where each point corresponds to (birth, death) pairs, and points which are furthest away from the birth = death line correspond to the most persistent features.\n",
    "\n",
 Note: the reason">
    "> Note: the reason we are talking about homology is that the $n$th Betti number is the rank of the $n$th homology group $H_n$.\n",
    "\n",
    "An example showing the birth and death of \"loops\" (technically the homology group $H_1$) is shown below.\n",
    "\n",
    "![ph1D](images/persistent_homology_1d.gif)\n",
    "<div style=\"text-align: left\">\n",
    "   <p style=\"text-align: left;\"> <b>Figure reference:</b> <a href=\"https://towardsdatascience.com/persistent-homology-with-examples-1974d4b9c3d0\">towardsdatascience.com/persistent-homology-with-examples-1974d4b9c3d0</a>. </p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's descend from abstraction and apply these concepts to our shape classification problem! In ``giotto-tda`` we can derive persistence diagrams from data by selecting the desired transformer in [gtda.homology](https://giotto-ai.github.io/gtda-docs/latest/modules/homology.html) and instantiating the class just like a ``scikit-learn`` estimator. Once the transformer is instantiated, we can make use of the fit-transform paradigm to generate the diagrams:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.homology import VietorisRipsPersistence\n",
    "\n",
    "# Track connected components, loops, and voids\n",
    "homology_dimensions = [0, 1, 2]\n",
    "\n",
    "# Collapse edges to speed up H2 persistence calculation!\n",
    "persistence = VietorisRipsPersistence(\n",
    "    metric=\"euclidean\",\n",
    "    homology_dimensions=homology_dimensions,\n",
    "    n_jobs=6,\n",
    "    collapse_edges=True,\n",
    ")\n",
    "\n",
    "diagrams_basic = persistence.fit_transform(point_clouds_basic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tip: in our example each point cloud has the same number of points so can be fed as single NumPy array. If you have varying number of points per point cloud, you can pass a list of arrays instead."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we have computed our persistence diagrams, we can compare how the circle, sphere and torus produce different patterns at each homology dimension:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.plotting import plot_diagram\n",
    "\n",
    "# Circle\n",
    "plot_diagram(diagrams_basic[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sphere\n",
    "plot_diagram(diagrams_basic[10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Torus\n",
    "plot_diagram(diagrams_basic[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the persistence diagrams we can see which persistent generators $H_{1,2,3}$ are most persistent for each shape. In particular, we see that:\n",
    "\n",
    "* the circle has one persistent generator $H_1$ corresponding to a loop,\n",
    "* the sphere has one persistent generator $H_2$ corresponding to a void,\n",
    "* the torus has three persistent generators, two for $H_1$ and one for $H_2$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although persistence diagrams are useful descriptors of the data, they cannot be used directly for machine learning applications. This is because different persistence diagrams may have different numbers of points, and basic operations like the addition and multiplication of diagrams are not well-defined.\n",
    "\n",
    "To overcome these limitations, a variety of proposals have been made to \"vectorize\" persistence diagrams via embeddings or kernels which are well-suited for machine learning. In ``giotto-tda``, we provide access to the most common vectorizations via the ``gtda.diagrams`` module.\n",
    "\n",
    "For example, one such feature is known as [persistence entropy](https://giotto-ai.github.io/gtda-docs/latest/theory/glossary.html#persistence-entropy) which measures the entropy of points in a diagram $D = \\{(b_i, d_i)\\}_{i\\in I}$ according to\n",
    "\n",
    "$$ E(D) = - \\sum_{i\\in I} p_i \\log p_i$$\n",
    "\n",
    "where $p_i =(d_i - b_i)/L_D$ and $L_D = \\sum_i (d_i - b_i)$. As we did for the persistence diagram calculation, we can use a transformer to calculate the persistent entropy associated with each homology generator $H_{0,1,2}$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.diagrams import PersistenceEntropy\n",
    "\n",
    "persistence_entropy = PersistenceEntropy()\n",
    "\n",
    "# calculate topological feature matrix\n",
    "X_basic = persistence_entropy.fit_transform(diagrams_basic)\n",
    "\n",
    "# expect shape - (n_point_clouds, n_homology_dims)\n",
    "X_basic.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nice - we have found a way to represent each point cloud in terms of just three numbers! Note that this does not depend on the number of points in the original point cloud, although calculating the $H_2$ generators becomes time consuming for point clouds of $O(10^3)$ points (at least on a standard laptop). \n",
    "\n",
    "By visualising the feature matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_point_cloud(X_basic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "we see that there are three distinct clusters, suggesting that a classifier should have no trouble finding a clean decision boundary!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a classifier\n",
    "\n",
    "With our topological features at hand, the last step is to train a classifier. Since our dataset is small, we will use a Random Forest and make use of the OOB score to simulate accuracy on a validation set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "rf = RandomForestClassifier(oob_score=True)\n",
    "rf.fit(X_basic, labels_basic)\n",
    "\n",
    "print(f\"OOB score: {rf.oob_score_:.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unsurprisingly, our classifier has perfectly separated the 3 classes. Next let's try to combine all the steps as a single pipeline!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Putting it all together\n",
    "\n",
    "Evidently, there are several data transformation steps that need to be executed in the right order to go from point clouds to predictions. Fortunately ``giotto-tda`` provides a ``Pipeline`` class to collect such sequences of transformations. Below is a small pipeline that reproduces all our steps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.pipeline import Pipeline\n",
    "\n",
    "steps = [\n",
    "    (\"persistence\", VietorisRipsPersistence(metric=\"euclidean\", homology_dimensions=homology_dimensions, n_jobs=6)),\n",
    "    (\"entropy\", PersistenceEntropy()),\n",
    "    (\"model\", RandomForestClassifier(oob_score=True)),\n",
    "]\n",
    "\n",
    "pipeline = Pipeline(steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By calling the ``fit`` method, the pipeline calls ``fit_transform`` on all transformers before calling ``fit`` on the final estimator:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline.fit(point_clouds_basic, labels_basic)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then access the Random Forest model by its ``model`` key to pick out the OOB score:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline[\"model\"].oob_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A more realistic example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the above example, the shapes were sufficiently distinct that it was easy to classify them according to their topological features. Here we consider a slightly more realistic example using point clouds from a variety of real-world objects. We will use the 3D dataset from Princeton's COS 429 Computer Vision [course](https://www.cs.princeton.edu/courses/archive/fall09/cos429/assignment3.html). The dataset consists of 200 models organised into 20 classes of 10 objects each. We'll use a subset consisting of a human, vase, dining_chair, and biplane:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from openml.datasets.functions import get_dataset\n",
    "\n",
    "df = get_dataset('shapes').get_data(dataset_format='dataframe')[0]\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, let's filter our ``pandas.DataFrame`` for a single member of a class, e.g. a biplane:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_point_cloud(df.query('target == \"biplane0\"')[[\"x\", \"y\", \"z\"]].values)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's collect all these point clouds in a single NumPy array:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "point_clouds = np.asarray(\n",
    "    [\n",
    "        df.query(\"target == @shape\")[[\"x\", \"y\", \"z\"]].values\n",
    "        for shape in df[\"target\"].unique()\n",
    "    ]\n",
    ")\n",
    "point_clouds.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we did with the simple shapes, we can calculate persistence diagrams for each of these point clouds:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "persistence = VietorisRipsPersistence(\n",
    "    metric=\"euclidean\",\n",
    "    homology_dimensions=homology_dimensions,\n",
    "    n_jobs=6,\n",
    "    collapse_edges=True,\n",
    ")\n",
    "persistence_diagrams = persistence.fit_transform(point_clouds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By zooming in on the resulting diagrams, we can spot some persistent generators, but with far less signal than before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index - (human_arms_out, 0), (vase, 10), (dining_chair, 20), (biplane, 30)\n",
    "index = 30\n",
    "plot_diagram(persistence_diagrams[index])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we convert each diagram into a 3-dimensional vector using persistent entropy and plot the resulting feature matrix:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "persistence_entropy = PersistenceEntropy(normalize=True)\n",
    "# Calculate topological feature matrix\n",
    "X = persistence_entropy.fit_transform(persistence_diagrams)\n",
    "# Visualise feature matrix\n",
    "plot_point_cloud(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike our simple shapes example, we do not see distinct clusters so we expect our classifier performance to be less than perfect in this case. Before we can train a model, we need to define a target vector for each point cloud. A crude, but simple way is to label each class with an integer starting from 0 to $n_\\mathrm{classes} - 1$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = np.zeros(40)\n",
    "labels[10:20] = 1\n",
    "labels[20:30] = 2\n",
    "labels[30:] = 3\n",
    "\n",
    "rf = RandomForestClassifier(oob_score=True, random_state=42)\n",
    "rf.fit(X, labels)\n",
    "rf.oob_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Improving our model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the analysis above we used persistence entropy to generate 3 topological features per persistence diagram (one per homology dimension). However, this choice of vectorisation is by no means unique, and in practice one can augment this information with other topological features to produce better models. \n",
    "\n",
    "For example, a simple feature to include is the number of off-diagonal points per homology dimension:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.diagrams import NumberOfPoints\n",
    "\n",
    "# Reshape single diagram to (n_samples, n_features, 3) format\n",
    "diagram = persistence_diagrams[0][None, :, :]\n",
    "# Get number of points for (H0, H1, H2)\n",
    "NumberOfPoints().fit_transform(diagram)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A more sophisticated feature is to calculate a vector of _amplitudes_ for each persistence diagram. Here the main idea is to partition a diagram into subdiagrams (one per homology dimension) and use a _metric_ to calculate the amplitude of each subdiagram relative to the trivial diagram (i.e. the one with points on the main diagonal). The result is a vector $\\boldsymbol{a} = (a_{q_1}, \\ldots, a_{q_n})$, where the $q_i$ range over the specified homology dimensions.\n",
    "\n",
    "For example, we can calculate the [Wasserstein amplitudes](https://giotto-ai.github.io/gtda-docs/latest/theory/glossary.html#wasserstein-and-bottleneck-distance) for a single diagram as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtda.diagrams import Amplitude\n",
    "\n",
    "Amplitude(metric='wasserstein').fit_transform(diagram)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So now that we know how to generate new topological features, let's combine them using ``scikit-learn``'s utility function for feature unions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import make_union\n",
    "\n",
    "# Select a variety of metrics to calculate amplitudes\n",
    "metrics = [\n",
    "    {\"metric\": metric}\n",
    "    for metric in [\"bottleneck\", \"wasserstein\", \"landscape\", \"persistence_image\"]\n",
    "]\n",
    "\n",
    "# Concatenate to generate 3 + 3 + (4 x 3) = 18 topological features\n",
    "feature_union = make_union(\n",
    "    PersistenceEntropy(normalize=True),\n",
    "    NumberOfPoints(n_jobs=-1),\n",
    "    *[Amplitude(**metric, n_jobs=-1) for metric in metrics]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The final step is to combine our feature extraction step with a classifier, fit the pipeline, and extract the OOB score:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = Pipeline(\n",
    "    [\n",
    "        (\"features\", feature_union),\n",
    "        (\"rf\", RandomForestClassifier(oob_score=True, random_state=42)),\n",
    "    ]\n",
    ")\n",
    "pipe.fit(persistence_diagrams, labels)\n",
    "pipe[\"rf\"].oob_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By creating additional features, we've managed to improve our baseline model by about 30% – not bad!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
